# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from .utils import maybe_prefix

# Square root of two. Within this file it is referenced only by the
# commented-out `generate_proposals` code inside `MLPSpeculator`.
SQRT2 = 2**0.5


class MLPSpeculatorLayerNorm(nn.Module):
    """
    Root-mean-square normalization over the last tensor axis.

    Each vector along the final axis is multiplied by
    ``rsqrt(mean(x**2) + eps)``; no mean is subtracted. Optionally the
    result is then scaled and shifted elementwise by learned parameters.
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
         fits in the range of your encoding scheme
         (i.e. fp16 requires eps >= 6e-8).
    elementwise_scale_and_shift : bool
        Include a learned scaling and shift term after normalization.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_scale_and_shift=True,
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        if not elementwise_scale_and_shift:
            # Pure normalization: no learnable parameters are registered.
            return
        # Allocated without initialization (torch.empty); the values must be
        # written before the module produces meaningful output.
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))

    def forward(self, x):
        """Normalize ``x`` along its last axis and apply the optional affine."""
        # Reciprocal RMS of every vector on the last axis, kept broadcastable.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = (x * inv_rms).type_as(x)
        if not self.elementwise_scale_and_shift:
            return normed
        return self.weight * normed + self.bias


class MLPSpeculator(nn.Module):
    """
    MLP speculator following the design described in
    "Accelerating Production LLMs with Combined Token/Embedding
    Speculators"
    https://arxiv.org/pdf/2404.19124

    Trained speculators of this type are available on HF hub at:
    https://huggingface.co/ibm-ai-platform and https://huggingface.co/ibm-granite

    For each of the ``num_lookahead_tokens`` speculative stages the module
    holds one token embedding, one input projection, one LM head and one
    layer norm. When ``tie_weights`` is set, a single embedding and a single
    layer norm instance are reused for every stage and all projections after
    the first share one ``nn.Linear``; LM heads are always built per stage.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        # An inner_dim of 0 in the config falls back to emb_dim.
        if config.inner_dim != 0:
            self.inner_dim = config.inner_dim
        else:
            self.inner_dim = config.emb_dim

        self.max_speculative_tokens = config.num_lookahead_tokens

        self.tie_weights = config.tie_weights
        self.scale_input = config.scale_input

        num_stages = self.max_speculative_tokens
        width = self.inner_dim

        if self.tie_weights:
            assert self.n_predict > 1, (
                "You cannot tie weights between stages when only 1 exists"
            )

        # Sub-modules are created and registered in the order
        # emb -> proj -> head -> ln for both the tied and untied layouts.

        # Token embeddings: one shared instance when tied, else one per stage.
        if self.tie_weights:
            shared_emb = VocabParallelEmbedding(
                config.vocab_size, width, org_num_embeddings=config.vocab_size
            )
            emb_modules = [shared_emb] * num_stages
        else:
            emb_modules = [
                VocabParallelEmbedding(config.vocab_size, width)
                for _ in range(num_stages)
            ]
        self.emb = nn.ModuleList(emb_modules)

        # Input projections. The first one maps from the base model's
        # emb_dim, which may differ from inner_dim, so it is never shared.
        if self.tie_weights:
            first_proj = nn.Linear(self.emb_dim, width, bias=False)
            shared_proj = nn.Linear(width, width, bias=False)
            proj_modules = [first_proj] + [shared_proj] * (num_stages - 1)
        else:
            proj_modules = []
            for stage in range(num_stages):
                in_features = self.emb_dim if stage == 0 else width
                proj_modules.append(nn.Linear(in_features, width, bias=False))
        self.proj = nn.ModuleList(proj_modules)

        # LM heads: built per stage regardless of tie_weights.
        head_modules = []
        for stage in range(num_stages):
            head_modules.append(
                ParallelLMHead(
                    self.vocab_size,
                    width,
                    bias=False,
                    prefix=maybe_prefix(prefix, f"head.{stage}"),
                )
            )
        self.head = nn.ModuleList(head_modules)

        # Per-stage layer norms: one shared instance when tied.
        if self.tie_weights:
            shared_ln = MLPSpeculatorLayerNorm(
                width, elementwise_scale_and_shift=True
            )
            ln_modules = [shared_ln] * num_stages
        else:
            ln_modules = [
                MLPSpeculatorLayerNorm(width, elementwise_scale_and_shift=True)
                for _ in range(num_stages)
            ]
        self.ln = nn.ModuleList(ln_modules)

        # Optional parameter-free norm sized for the base model's emb_dim.
        if self.scale_input:
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False
            )

        # Mixing coefficients for the hidden state and the token embedding.
        self.state_weight = 0.5 ** (0.5 / config.n_predict)
        emb_share = 1 - self.state_weight**2
        self.emb_weight = math.sqrt(emb_share * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(
            config.vocab_size, config.vocab_size, 1.0
        )

    # NOTE(woosuk): This method is commented out because it is old code
    # using V0. We should either port it to V1 or remove it.

    # def generate_proposals(
    #     self,
    #     input_ids: torch.Tensor,
    #     previous_hidden_states: torch.Tensor,
    #     num_predict_tokens: int,
    #     sampling_metadata: SamplingMetadata,
    # ) -> list[SamplerOutput]:
    #     if num_predict_tokens > self.max_speculative_tokens:
    #         raise ValueError(f"Max speculative tokens for model is "
    #                          f"{self.max_speculative_tokens}, but "
    #                          f"{num_predict_tokens} were requested")

    #     # b x 1 x d
    #     previous_hidden_states = previous_hidden_states.unsqueeze(1)

    #     if self.scale_input:
    #         previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

    #     # b x 1
    #     last_tokens = input_ids.unsqueeze(1)

    #     next_tokens = []

    #     for head_index in range(num_predict_tokens):

    #         # Project and predict
    #         z = self.emb[head_index](last_tokens)  # b k d
    #         states = self.proj[head_index](previous_hidden_states)

    #         # Weighted add of state_weight*state and emb_weight*z
    #         # Let subsequent LN take care of denominator
    #         # state_weight is close to 1, so shouldn't be any precision issues
    #         states.add_(z, alpha=self.emb_weight / self.state_weight)

    #         states = self.activation(self.ln[head_index](states))  # b k d
    #         previous_hidden_states = states
    #         # TODO: not yet supporting top_k_tokens_per_head
    #         states = states.flatten(0, 1)

    #         logits = self.logits_processor(self.head[head_index], states,
    #                                        sampling_metadata)

    #         output = self.sampler(logits, sampling_metadata)
    #         last_tokens = output.sampled_token_ids
    #         next_tokens.append(output)

    #     return next_tokens

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Every occurrence of ``"speculator."`` is removed from each incoming
        name before it is looked up among ``self.named_parameters()``. Names
        that match no parameter are skipped silently. A matching parameter is
        filled through its own ``weight_loader`` attribute when it has one,
        otherwise through ``default_weight_loader``.

        Returns the set of (stripped) parameter names that were loaded.
        """
        own_params = dict(self.named_parameters())
        loaded: set[str] = set()
        for ckpt_name, tensor in weights:
            key = ckpt_name.replace("speculator.", "")
            target = own_params.get(key)
            if target is None:
                continue
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)
            loaded.add(key)
        return loaded
